{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec9cdb03",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick up edits to imported modules without restarting the kernel\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Third-party libraries\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "# Project modules\n",
    "from lee_et_al_2023.src import analysis, base, data_loaders\n",
    "\n",
    "# Project-wide plotting setup (defined in lee_et_al_2023.src.base)\n",
    "base.set_visual_settings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce4130f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the cleaned datasets returned by the project's data loader.\n",
    "models, humans, panel, subjects = data_loaders.get_clean()\n",
    "\n",
    "# Fix the random state so the shuffled baseline is identical on every Restart & Run All.\n",
    "# NOTE(review): assumes analysis.shuffle_df draws from NumPy's global RNG - confirm; if it\n",
    "# takes its own seed / random_state argument, pass SHUFFLE_SEED there instead.\n",
    "SHUFFLE_SEED = 0\n",
    "np.random.seed(SHUFFLE_SEED)\n",
    "\n",
    "# Add a baseline entry built from the GNN entry with shuffle='molecules'.\n",
    "models['GNN (shuffled baseline)'] = analysis.shuffle_df(models['GNN'], shuffle='molecules')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efa9540f",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# Run analysis.fast_process on the `humans` and `models` objects loaded above.\n",
    "# The result is indexed by group name in the plotting cell (e.g. ott['Human']) and each\n",
    "# entry exposes .median(). Presumably these are per-molecule correlations to the panel\n",
    "# mean, judging by the x-axis label used there - TODO confirm against fast_process.\n",
    "ott = analysis.fast_process(humans, models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "755f8e64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identifier of the figure this notebook produces.\n",
    "# NOTE(review): not referenced by any other cell in this notebook - presumably meant for\n",
    "# a save-figure filename; confirm or remove.\n",
    "fig_name = '2E'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cf25192",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Line colour for each group drawn in the ECDF plot; the keys are also used to index `ott`.\n",
    "# Insertion order is the order in which the curves are drawn and listed in the legend.\n",
    "colormap = {\n",
    "    'RF': '#8da0cb',\n",
    "    'GNN': '#fc8d62',\n",
    "    'Human': '#66c2a5',\n",
    "    'GNN (shuffled baseline)': 'grey',\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c21f94e",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "outputs": [],
   "source": [
    "# Empirical CDF of ott[group] for every group in `colormap`, drawn on one shared Axes.\n",
    "# The Axes handle is created explicitly so the styling below never depends on a variable\n",
    "# leaked from the plotting loop.\n",
    "fig, ax = plt.subplots(figsize=(8, 6))\n",
    "for group, color in colormap.items():\n",
    "    sns.ecdfplot(ott[group], color=color, stat='proportion', label=group, ax=ax)\n",
    "\n",
    "# Dashed horizontal reference at a proportion of 0.5.\n",
    "ax.axhline(y=0.5, color='black', linestyle='--', linewidth=2)\n",
    "ax.set_xlabel('Correlation to Panel Mean', fontsize=20)\n",
    "ax.set_ylabel('Proportion of molecules', fontsize=20)\n",
    "ax.set_xlim(-0.25, 0.75)\n",
    "ax.set_xticks(np.linspace(-0.25, 0.75, 5))\n",
    "\n",
    "ax.legend()\n",
    "\n",
    "# Heavy black bottom/left spines, no top/right spines, no grid.\n",
    "for side in ('bottom', 'left'):\n",
    "    ax.spines[side].set_linewidth(3)\n",
    "    ax.spines[side].set_edgecolor('black')\n",
    "for side in ('top', 'right'):\n",
    "    ax.spines[side].set_visible(False)\n",
    "ax.grid(False)\n",
    "\n",
    "# Median of each group, computed once and reused for the markers and arrows.\n",
    "median_groups = ('Human', 'GNN', 'RF', 'GNN (shuffled baseline)')\n",
    "medians = {name: ott[name].median() for name in median_groups}\n",
    "human_median = medians['Human']\n",
    "\n",
    "for group in median_groups:\n",
    "    # Dotted vertical marker at this group's median.\n",
    "    ax.axvline(x=medians[group], linestyle=':', linewidth=1, color='black')\n",
    "    if group in ('Human', 'GNN (shuffled baseline)'):\n",
    "        continue\n",
    "    # Horizontal arrow from the human median to this model's median.\n",
    "    # NOTE(review): the arrow height is 0.5 + the model's median, i.e. it is tied to an\n",
    "    # x-axis quantity - presumably to stagger the two arrows; confirm this is intended.\n",
    "    ax.arrow(\n",
    "        human_median,\n",
    "        0.5 + medians[group],\n",
    "        dx=medians[group] - human_median,\n",
    "        dy=0,\n",
    "        width=0.03,\n",
    "        length_includes_head=True,\n",
    "        head_width=0.03,\n",
    "        head_length=0.01,\n",
    "        color=colormap[group],\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80903989",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "formats": "auto:light,ipynb",
   "notebook_metadata_filter": "-all"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
